"""
@file   layers.py
@author Jianfei Guo, Shanghai AI Lab
@brief  Single linear layers with configurable activation and other properties.
"""

__all__ = [
    'DenseLayer',
    'BatchDenseLayer',
    'get_nonlinearity'
]

import math
import numpy as np
from typing import Callable, NamedTuple, Optional, Union

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from nr3d_lib.utils import is_scalar, partialclass, torch_dtype

DEBUG_TENSOR = False

#------------------------------------------------------------------------
# NOTE: Some of the codes are borrowed from https://github.com/vsitzmann/siren/modules.py
#------------------------------------------------------------------------

########################
# Initialization methods
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # For PINNet, Raissi et al. 2019
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # grab from upstream pytorch branch and paste here for now
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def init_weights_trunc_normal(m):
    """Truncated-normal (Glorot-style) weight init for dense layers.

    For PINNet, Raissi et al. 2019.
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        m (nn.Module): Only nn.Linear / BatchDenseLayer are initialized;
            any other module passes through untouched.
    """
    if isinstance(m, (nn.Linear, BatchDenseLayer)):
        if hasattr(m, 'weight'):
            # NOTE: Index from the back so this also works for BatchDenseLayer's
            #       batched weights of shape [..., out_features, in_features].
            #       The original size(1)/size(0) would index prefix (batch) dims
            #       there; for plain nn.Linear [out, in] the result is identical.
            fan_in = m.weight.size(-1)
            fan_out = m.weight.size(-2)
            std = math.sqrt(2.0 / float(fan_in + fan_out))
            mean = 0.
            # initialize with the same behavior as tf.truncated_normal
            # "The generated values follow a normal distribution with specified mean and
            # standard deviation, except that values whose magnitude is more than 2
            # standard deviations from the mean are dropped and re-picked."
            _no_grad_trunc_normal_(m.weight, mean, std, -2 * std, 2 * std)


def init_weights_normal(m):
    """Kaiming-normal (fan-in, ReLU) weight init for dense layers."""
    if isinstance(m, (nn.Linear, BatchDenseLayer)) and hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

def init_weights_selu(m):
    """SELU-appropriate weight init: N(0, 1/sqrt(fan_in))."""
    if isinstance(m, (nn.Linear, BatchDenseLayer)) and hasattr(m, 'weight'):
        fan_in = m.weight.size(-1)
        nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    """ELU-appropriate weight init: N(0, sqrt(1.55052...)/sqrt(fan_in))."""
    if isinstance(m, (nn.Linear, BatchDenseLayer)) and hasattr(m, 'weight'):
        fan_in = m.weight.size(-1)
        nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(fan_in))


def init_weights_xavier(m):
    """Xavier/Glorot normal weight init for dense layers."""
    if isinstance(m, (nn.Linear, BatchDenseLayer)) and hasattr(m, 'weight'):
        nn.init.xavier_normal_(m.weight)

class Sine(nn.Module):
    """Sinusoidal activation `sin(w0 * x)` (SIREN, Sitzmann et al. 2020)."""
    def __init__(self, w0):
        super().__init__()
        # Frequency scale applied to the input before the sine
        self.w0 = w0

    def forward(self, x):
        return (self.w0 * x).sin()

    def extra_repr(self) -> str:
        return f"w0={self.w0}"

class ClipTanh(nn.Module):
    """Scaled tanh that softly clips outputs into [-bound, bound]:
    y = bound * tanh(x / bound).
    """
    def __init__(self, bound=1) -> None:
        super().__init__()
        # Store as a buffer (possibly multi-channel) so it moves with the module.
        bound_list = [bound] if is_scalar(bound) else bound
        self.register_buffer('bound', torch.tensor(bound_list))

    def forward(self, x: torch.Tensor):
        return torch.tanh(x / self.bound) * self.bound

    def extra_repr(self) -> str:
        return f"bound={self.bound}"

def _get_sine_init(w0=30.):
    def sine_init(m):
        with torch.no_grad():
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1)
                std = np.sqrt(6 / num_input) / w0
                # See supplement Sec. 1.5 for discussion of factor 30
                m.weight.uniform_(-std, std)
                if hasattr(m, 'bias') and m.bias is not None:
                    m.bias.uniform_(-std, std)
    return sine_init


def first_layer_sine_init(m):
    """SIREN first-layer init: U(-1/fan_in, 1/fan_in) on weight (and bias if present).

    See SIREN paper Sec. 3.2 (final paragraph) and supplement Sec. 1.5
    for discussion of the factor 30.
    """
    with torch.no_grad():
        if not hasattr(m, 'weight'):
            return
        bound = 1.0 / m.weight.size(-1)
        m.weight.uniform_(-bound, bound)
        if getattr(m, 'bias', None) is not None:
            m.bias.uniform_(-bound, bound)

def nop(m):
    """Do-nothing initializer placeholder (used when no init should be applied)."""
    return None

# nonlinearity_ret = namedtuple("namedtuple_nl_gain_init_firstinit", "nl gain init first_init")
class namedtuple_nl_gain_init_firstinit(NamedTuple):
    """Bundle returned by `get_nonlinearity`: the activation module itself,
    its init gain, and the weight-init functions for hidden / first layers."""
    nl: nn.Module
    gain: float
    init: Callable
    first_init: Callable

# NOTE: Currently, the init functions below are not used, since they break the original training.
#       We are using pytorch's original initialization functions instead (default behavior).
# nonlinearity_map = {
#     'relu':     namedtuple_nl_gain_init_firstinit(partialclass(nn.ReLU, inplace=True), init_weights_normal, None),
#     'sigmoid':  namedtuple_nl_gain_init_firstinit(nn.Sigmoid, init_weights_xavier, None),
#     'tanh':     namedtuple_nl_gain_init_firstinit(nn.Tanh, init_weights_xavier, None),
#     'selu':     namedtuple_nl_gain_init_firstinit(partialclass(nn.SELU, inplace=True), init_weights_selu, None),
#     'softplus': namedtuple_nl_gain_init_firstinit(nn.Softplus, init_weights_normal, None),
#     'elu':      namedtuple_nl_gain_init_firstinit(partialclass(nn.ELU, inplace=True), init_weights_elu, None)
# }

# Maps a lowercase nonlinearity name to the nn.Module class implementing it.
# Classes are instantiated later (in `get_nonlinearity`) with user-given params;
# in-place variants are pre-bound via `partialclass` where available.
nonlinearity_map = {
    'relu':         partialclass(nn.ReLU, inplace=True),
    'elu':          partialclass(nn.ELU, inplace=True),
    'selu':         partialclass(nn.SELU, inplace=True),
    'leaky_relu':   partialclass(nn.LeakyReLU, inplace=True),
    'softplus':     nn.Softplus,
    'sigmoid':      nn.Sigmoid,
    'tanh':         nn.Tanh,
    'cliptanh':     ClipTanh
}

def get_nonlinearity(config: Optional[Union[str, dict]], **param):
    """Resolve a nonlinearity config into (nl, gain, init, first_init).

    Args:
        config (Optional[Union[str, dict]]): None or 'none' for identity;
            a name string (e.g. 'relu'); or a dict with at least {'type': name}
            plus optional constructor params and an optional 'gain' override.
        **param: Extra constructor params for the activation; entries in
            `config` take precedence on key collision.

    Returns:
        namedtuple_nl_gain_init_firstinit: (activation module or None, init
            gain, hidden-layer init fn, first-layer init fn).
    """
    if isinstance(config, str):
        if config.lower() == 'none':
            return namedtuple_nl_gain_init_firstinit(None, 1, nop, nop)
        else:
            config = dict(type=config.lower())
    elif config is None:
        return namedtuple_nl_gain_init_firstinit(None, 1, nop, nop)

    if isinstance(config, dict):
        assert 'type' in config, 'You should provide the type of the nonlinearity'
        name = config['type'].lower()
        # Merge keyword params with the config dict (config wins). 'type' and
        # 'gain' are meta keys, not constructor arguments.
        # NOTE: fixes the original behavior where a 'gain' entry leaked into
        #       the activation constructor (TypeError) and the **param kwargs
        #       were silently discarded.
        param = {**param, **{k: v for k, v in config.items() if k not in ('type', 'gain')}}
        if name == 'siren':
            param.setdefault('w0', 30.)
            nl = Sine(**param)
            init_fn = _get_sine_init(w0=param['w0'])
            first_init_fn = first_layer_sine_init
            gain = 1
        else:
            cls = nonlinearity_map[name]; init_fn = nop; first_init_fn = nop
            nl = cls(**param)
            gain = config.get('gain', None)
            if gain is None:
                if name == 'leaky_relu':
                    gain = init.calculate_gain(name, nl.negative_slope)
                elif name == 'softplus':
                    if nl.beta >= 5.0: # When beta is bigger than 5., the softplus is pretty much approx to ReLU
                        gain = math.sqrt(2)
                    else:
                        gain = 1.
                elif name in ['cliptanh']:
                    gain = 1.
                else:
                    gain = init.calculate_gain(name)
        if first_init_fn is None:
            first_init_fn = init_fn
        nl.name = name
        return namedtuple_nl_gain_init_firstinit(nl, gain, init_fn, first_init_fn)
    else:
        raise RuntimeError(f"Invalid nonlinearity={config}")


class DenseLayer(nn.Module):
    def __init__(
        self, 
        in_features: int, out_features: int, *,
        bias: bool=True, activation: Union[str, dict, nn.Module]=None, should_init=True, 
        # For equal lr
        equal_lr=False, lr_mul: float=1, weight_init: float=1, bias_init: float=1, 
        # Factory_kwargs
        dtype: Union[str, torch.dtype]=torch.float, device: torch.device=torch.device('cuda'), ):
        """ Construct a dense layer

        Args:
            in_features (int): Input feature width.
            out_features (int): Output feature width.
            bias (bool, optional): Whether hidden layers have bias. Defaults to True.
            activation (Union[str, dict, nn.Module], optional): Can be a name, a config dict, or a nn.Module. Defaults to None.
            should_init (bool, optional): Whether conduct dense layer initialization. Defaults to True.
            equal_lr (bool, optional): Whether use `equal_lr`. Defaults to False.
            lr_mul (float, optional): equal_lr's multiplication factor. Defaults to 1.
            weight_init (float, optional): weight init mul factor. Defaults to 1.
            bias_init (float, optional): bias init mul factor. Defaults to 1.
            dtype (Union[str, torch.dtype], optional): Network param's dtype. Defaults to torch.float.
                Can be a string (e.g. "float", "half") or a torch.dtype. 
            device (torch.device, optional): Network param's device. Defaults to torch.device('cuda').
        """

        super().__init__()
        self.dtype = torch_dtype(dtype)
        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        self.equal_lr = equal_lr

        # NOTE: Always define the runtime gains, so `forward` works even when
        #       `should_init=False` (originally they were only set inside the
        #       init branch, causing AttributeError otherwise).
        self.weight_gain = (lr_mul / np.sqrt(in_features)) if equal_lr else lr_mul
        self.bias_gain = lr_mul

        # NOTE: Parameters should always be stored as float32. `self.dtype` is only respected when forward.
        self.weight = nn.Parameter(torch.empty((out_features, in_features), device=self.device, dtype=torch.float))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, device=self.device, dtype=torch.float))
        else:
            self.register_parameter('bias', None)
        
        if isinstance(activation, (str, dict)):
            activation = get_nonlinearity(activation).nl
        self.activation = activation

        # Init parameters
        if should_init:
            if equal_lr:
                bound = weight_init / lr_mul
                init.uniform_(self.weight, -bound, bound)
            else: # NOTE: pytorch's original nn.Linear's reset_parameters() function
                # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
                # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
                # https://github.com/pytorch/pytorch/issues/57109
                init.kaiming_uniform_(self.weight, a=math.sqrt(5))
                with torch.no_grad(): 
                    self.weight *= (weight_init / lr_mul)

            if self.bias is not None:
                fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
                bound = (1 / math.sqrt(fan_in) if fan_in > 0 else 0) * (bias_init / lr_mul)
                init.uniform_(self.bias, -bound, bound)

    def get_weight_reg(self, norm_type: float = 2.0):
        """Return a 1D tensor with the `norm_type`-norm of the weight (and bias)."""
        norms = [self.weight.norm(p=norm_type)]
        if self.bias is not None:
            norms.append(self.bias.norm(p=norm_type))
        # NOTE: `torch.cat` cannot concatenate 0-dim tensors; `stack` is required
        #       here (the norms above are scalars).
        return torch.stack(norms)

    def forward(self, x: torch.Tensor, max_channel: int = None):
        """Apply the (optionally channel-sliced) linear map plus activation.

        Args:
            x (torch.Tensor): [..., in_features] input.
            max_channel (int, optional): If given, only the first `max_channel`
                input channels of the weight (and output channels of the bias)
                are used. Defaults to None.

        Returns:
            torch.Tensor: [..., out_features] output.
        """
        with torch.autocast(device_type='cuda', dtype=self.dtype):
            weight = self.weight[:, :max_channel] if max_channel is not None else self.weight
            bias = self.bias[:max_channel] if (max_channel is not None and self.bias is not None) else self.bias
            if (self.weight_gain == 1) and (self.bias is None or self.bias_gain == 1):
                # NOTE: Removed a stray `torch.cuda.synchronize()` debug
                #       leftover that stalled the hot path (and crashed on
                #       CPU-only builds).
                out = F.linear(x, weight, bias)
            else:
                # NOTE: Guard against a missing bias (the original multiplied
                #       `None * self.bias_gain` unconditionally).
                out = F.linear(x, weight * self.weight_gain,
                               bias * self.bias_gain if bias is not None else None)
            if self.activation is not None:
                out = self.activation(out)
            return out

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, equal_lr={self.equal_lr}"

    @classmethod
    def from_params(
        cls,
        weight_param: torch.Tensor, bias_param: torch.Tensor = None, 
        activation: Union[str, dict, nn.Module]=None, 
        # for equal lr
        equal_lr=False, lr_mul: float=1, 
        # factory_kwargs
        dtype: Union[str, torch.dtype]=None, device: torch.device=None):
        """ Construct a dense layer directly given weight and bias parameter.

        Args:
            weight_param (torch.Tensor): The given weight parameter
            bias_param (torch.Tensor, optional): The given bias parameter. Defaults to None.
            activation (Union[str, dict, nn.Module], optional): Can be a name, a config dict, or a nn.Module. Defaults to None.
            equal_lr (bool, optional): Whether use `equal_lr`. Defaults to False.
            lr_mul (float, optional): equal_lr's multiplication factor. Defaults to 1.
            dtype (Union[str, torch.dtype], optional): Network param's dtype. Defaults to None (keep constructor default).
            device (torch.device, optional): Network param's device. Defaults to None (keep `weight_param`'s device).

        Returns:
            DenseLayer: The constructed DenseLayer instance
        """
        out_features, in_features = weight_param.shape
        # NOTE: `should_init=False` since the fresh parameters are replaced
        #       right below; also construct on the given tensor's device rather
        #       than the default 'cuda' to avoid a pointless (or impossible)
        #       CUDA allocation.
        o = cls(in_features, out_features, bias=bias_param is not None,
                activation=activation, should_init=False,
                device=device if device is not None else weight_param.device)
        # NOTE: Wrap in nn.Parameter — nn.Module forbids assigning a plain
        #       tensor over a registered parameter (the original raised TypeError).
        o.weight = nn.Parameter(weight_param.to(dtype=torch.float, device=device))
        if bias_param is not None:
            o.bias = nn.Parameter(bias_param.to(dtype=torch.float, device=device))
        o.equal_lr = equal_lr
        if device is not None:
            o.device = device
        if dtype is not None:
            o.dtype = torch_dtype(dtype)

        if equal_lr:
            o.weight_gain = lr_mul / np.sqrt(o.in_features)
            o.bias_gain = lr_mul
        else:
            o.weight_gain = lr_mul
            o.bias_gain = lr_mul
        return o

class BatchDenseLayer(nn.Module):
    # NOTE: currently, only support _weight_init
    def __init__(self, weight: torch.Tensor, bias: torch.Tensor, activation: Union[str, dict]=None):
        """ Batched dense layer (Multiple dense layers in a batch)
            NOTE: `...` means arbitrary prefix dim.
            Given multiple weights [..., out_ch, in_ch] and biases [..., out_ch], 
            construct multiple dense layers, with inputs feeding into each dense layer
            
            forward: `...` is the same prefix-batch-dims of weights and biases
                input: [..., in_ch]
                output: [..., out_ch]

        Args:
            weight (torch.Tensor): [..., out_ch, in_ch], Multiple dense layers' weights
            bias (torch.Tensor): [..., out_ch], Multiple dense layers' bias
            activation (Union[str, dict], optional): A name or config dict of activation. Defaults to None.
        """
        super().__init__()

        if bias is not None:
            assert [*weight.shape[:-2]] == [*bias.shape[:-1]], f'weight and bias should have the same prefix shape, current: {[*weight.shape[:-2]]}, {[*bias.shape[:-1]]}'
            assert bias.shape[-1] == weight.shape[-2], f'bias should have shape suffix = {weight.shape[-2]}, but current is {bias.shape[-1]}'
        # NOTE: Stored as plain tensors (not nn.Parameter); presumably these are
        #       externally generated weights — confirm against callers.
        self.weight = weight
        self.bias = bias
        self.prefix = [*weight.shape[:-2]]
        
        if isinstance(activation, (str, dict)):
            activation = get_nonlinearity(activation).nl
        self.activation = activation

    def __repr__(self):
        return f"BatchDenseLayer(in_dim={self.weight.shape[-1]}, out_dim={self.weight.shape[-2]}, prefix_shape={self.prefix}, activation={self.activation})"

    def forward(self, x):
        """[..., in_ch] -> [..., out_ch], one matmul per prefix-batch entry."""
        assert [*x.shape[:len(self.prefix)]] == self.prefix
        # Treat x as a row vector per batch entry: [..., 1, in] @ [..., in, out]
        out = x.unsqueeze(-2).matmul(self.weight.transpose(-1, -2)).squeeze(-2)
        if self.bias is not None:
            out += self.bias
        if DEBUG_TENSOR:
            # NOTE: Guard against a missing bias (the original read
            #       `self.bias.shape` unconditionally and crashed when bias=None).
            bias_shape = [*self.bias.shape] if self.bias is not None else None
            print(f"{[*x.shape]} @ {[*self.weight.shape]}.t + {bias_shape} = {[*out.shape]}" )
        if self.activation is not None:
            out = self.activation(out)
        return out

if __name__ == "__main__":
    def unit_test_batched_dense_layer(device=torch.device('cuda'), dtype=torch.float):
        """Smoke test: batched dense forward with a softplus activation."""
        global DEBUG_TENSOR
        DEBUG_TENSOR = True  # turn on the per-forward shape print

        from icecream import ic
        prefix_shape = [7, 13]
        in_ch, out_ch = 16, 3
        weight = torch.randn(*prefix_shape, out_ch, in_ch, device=device, dtype=dtype)
        bias = torch.randn(*prefix_shape, out_ch, device=device, dtype=dtype)
        m = BatchDenseLayer(weight, bias, activation={'type':'softplus', 'beta':100.0})
        ic(m)

        x = torch.randn(*prefix_shape, in_ch, device=device, dtype=dtype)
        y = m(x)
        ic(y.shape, y.dtype)

    unit_test_batched_dense_layer()